
[#12699][feat] AutoDeploy: Support Piecewise CG for VLMs#12749

Closed
nvchenghaoz wants to merge 30 commits into NVIDIA:main from
nv-auto-deploy:chenghao/piecewise_update_0402

Conversation

@nvchenghaoz
Collaborator

@nvchenghaoz nvchenghaoz commented Apr 3, 2026

#12699

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Mixture-of-Experts (MoE) quantization support to improve model efficiency and throughput.
  • Improvements

    • Enhanced Qwen 3.5 MoE model configuration with increased throughput capacity and batch processing limits.
    • Improved distributed computing support for multi-GPU deployments.
    • Optimized model compilation pipeline for better performance across various hardware configurations.
  • Tests

    • Extended test coverage for MoE sharding and quantization scenarios.

taylor-yb-lee and others added 27 commits April 2, 2026 01:27
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
(This was good for Qwen3.5 with long input sequences (>= 15000 tokens).)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
  Fix lm_head not sharded in Qwen3_5MoeForConditionalGeneration export
  The Qwen3_5MoeForConditionalGeneration factory exports only
  model.language_model, leaving lm_head outside the graph.
  This causes lm_head to run unsharded (248320x4096 on every GPU) and prevents gather_logits_before_lm_head from optimizing it.
  Graft lm_head into the exported graph during post_process:
  - Capture lm_head from the parent model in from_autoinferred()
  - Insert auto_deploy.torch_linear_simple + aten.to.dtype nodes with explicit names for filter matching
  - Set _lm_head_grafted flag so the parent forward skips redundant
    lm_head during cache init
  - Add "lm_head": "gather" to the manual tp_plan for column-split sharding capture.
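The grafting step described above can be sketched with a toy torch.fx graph. This is an illustrative reconstruction, not the actual transform: the module, tensor sizes, and node wiring below are invented for the example.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TextModel(nn.Module):
    """Stand-in for the exported language model (hidden size 8)."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


# Trace the text model alone -- analogous to exporting only model.language_model,
# which leaves lm_head outside the graph.
gm = symbolic_trace(TextModel())

# "Graft" an lm_head submodule into the already-traced graph.
lm_head = nn.Linear(8, 16, bias=False)
gm.add_submodule("lm_head", lm_head)

# Rewire the graph output: hidden_states -> lm_head -> logits.
out_node = next(n for n in gm.graph.nodes if n.op == "output")
hidden = out_node.args[0]
with gm.graph.inserting_before(out_node):
    logits = gm.graph.call_module("lm_head", args=(hidden,))
out_node.args = (logits,)
gm.recompile()

print(gm(torch.randn(2, 8)).shape)  # torch.Size([2, 16])
```

Once the node lives inside the graph with a stable name, later passes (sharding, gather_logits_before_lm_head) can match it by name, which is the point of the explicit-naming step in the commit.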

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
  1. Checks position_ids is not None and not has_images and not has_videos; this is the AD runtime text-only path (every call in your benchmark)
  2. Expands position_ids to 3D (same logic as the original)
  3. Calls self.language_model(inputs_embeds=..., position_ids=...) with just two args; no **kwargs

  This skips all of the following overhead that was running on every forward step:
  - 12 kwargs.get() calls for multimodal metadata extraction
  - has_chunk_mm_layout check with .numel() and .item() calls
  - mrope_delta_cache lookup loop over all kwargs
  - 3D position_ids conditional branching with cu_seqlen tensor ops
  - 10-item kwargs.pop() loop + key-suffix scan
  - **kwargs passthrough to language_model (forces flatten/hash)
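The steps above can be sketched as a minimal fast path. The class and attribute names are stand-ins for the real model, and the multimodal branch is elided:

```python
import torch
import torch.nn as nn


class TinyVLM(nn.Module):
    """Illustrative stand-in; method and flag names are hypothetical."""

    def __init__(self):
        super().__init__()
        # Toy language model: just echoes its two keyword arguments.
        self.language_model = lambda inputs_embeds, position_ids: (inputs_embeds, position_ids)

    def forward(self, inputs_embeds, position_ids=None, **kwargs):
        if position_ids is not None and not kwargs.get("has_images") and not kwargs.get("has_videos"):
            if position_ids.ndim == 2:
                # Expand 2D [B, S] position_ids to the 3D mrope layout [3, B, S].
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
            # Exactly two keyword args -- no **kwargs passthrough, so no
            # flatten/hash cost on every decode step.
            return self.language_model(inputs_embeds=inputs_embeds, position_ids=position_ids)
        raise NotImplementedError("multimodal path elided in this sketch")


m = TinyVLM()
emb = torch.randn(2, 5, 8)
pos = torch.arange(5).expand(2, 5)
_, pos3d = m(emb, pos)
print(pos3d.shape)  # torch.Size([3, 2, 5])
```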

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Since commit 93d99f1, the Qwen3.5 model has been created by Qwen3_5MoeFactory.
However, it exports only the inner text model as a GraphModule, wrapping it in a non-GraphModule wrapper.
This broke piecewise CUDA graph capture.

This commit fixes it by exporting the full model as a single GraphModule.
Also added _init_dynamic_shape_lookup(), which returns a 2D position_ids spec for full-model export (the 2D→3D expansion happens inside the traced graph).

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…by forward().

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
remove redundant comment

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…wise cuda graph.

Exported only the text model (decoder + lm_head) as a graph module and added a named_args preprocessing hook on CachedSequenceInterface to convert input_ids to inputs_embeds outside the graph.
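A rough sketch of such a preprocessing hook, assuming a hypothetical hook registry; the real CachedSequenceInterface API may differ:

```python
import torch
import torch.nn as nn


class SequenceInterfaceSketch:
    """Illustrative stand-in for an interface with named-args pre-hooks."""

    def __init__(self):
        self._pre_hooks = []

    def add_named_args_hook(self, fn):
        self._pre_hooks.append(fn)

    def prepare(self, named_args: dict) -> dict:
        for fn in self._pre_hooks:
            named_args = fn(named_args)
        return named_args


embed = nn.Embedding(100, 16)


def input_ids_to_embeds(named_args):
    # Run the embedding *outside* the captured graph, so the graph only
    # ever sees inputs_embeds -- required when the piecewise CUDA graph
    # captures just the text decoder.
    ids = named_args.pop("input_ids")
    named_args["inputs_embeds"] = embed(ids)
    return named_args


iface = SequenceInterfaceSketch()
iface.add_named_args_hook(input_ids_to_embeds)
out = iface.prepare({"input_ids": torch.randint(0, 100, (2, 5))})
print(out["inputs_embeds"].shape)  # torch.Size([2, 5, 16])
```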

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 3, 2026 21:35
@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

The PR updates the Qwen3.5 MoE 400B model deployment infrastructure by extending model capacity configuration, restructuring lm_head ownership to the text model, refactoring CUDA graph capture for dynamic dimensions and nested modules, and extending TP-sharding with MoE-specific quantization scale handling.

Changes

Cohort / File(s) Summary
Model Configuration & Architecture
examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml, tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py
Config increases max_num_tokens (8192→16000), max_batch_size (32→256), expands cuda_graph_batch_sizes, adds world_size: 8, and introduces fuse_nvfp4_moe transform. Model code restructures lm_head ownership: Qwen3_5MoeTextModel now contains and computes logits via set_lm_head(); Qwen3_5MoeForCausalLM and Qwen3_5MoeForConditionalGeneration wire lm_head into text model and consume outputs.logits; added embedding accessor methods and factory registration for AutoModelForCausalLMFactory.
CUDA Graph Capture Infrastructure
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
Major refactoring to support dynamic dimensions and nested modules: added per-input dynamic_dims tracking with auto-detection across batch sizes, replaced batch-dimension slicing with narrow() on detected dims, introduced _static_input_buffers and kwargs tensor pre-allocation for stable addresses, added _reconstruct_output() for structured output unflattening, replaced token-count extraction with batch_info_host computation, and added wrapper→inner GraphModule compilation via pre-hook capture.
Compilation Transform Pipeline
tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
Replaced single-pass root module compilation with hierarchy traversal to identify and separately compile top-level GraphModule instances; added _set_submodule() helper to swap compiled graphs into their dotted paths; compilation now passes full_model=mod for non-root targets; added logging for compilation targets.
Quantization Hook Updates
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Updated NVFP4LinearQuantizationFromConfig.load_hook to prepend prefix to weight_name for correct state_dict key lookup in nested module hierarchies.
TP Sharding & MoE Scale Handling
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Extended TP sharding for quantization by prefixing weight names in scale checkpoint loading; changed weight sharding hook registration from top-level to submodule-level; added _shard_nvfp4_moe_scale() and _tp_shard_moe_scale() helpers for MoE expert scale TP-sharding with CUTLASS-compatible format preservation; updated _insert_sharded_moe() to TP-shard blocked scales for w_up, w_down, w_gate when tp_size > 1.
Graph Utility Enhancements
tensorrt_llm/_torch/auto_deploy/utils/_graph.py
Updated get_lm_head_node() to unwrap column-sharded LM-head graphs by detecting and replacing torch.ops.auto_deploy.trtllm_dist_all_gather wrapper nodes with their inputs.
Test Coverage
tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py
Added comprehensive MoE TP-sharding test coverage: new NVFP4MoEOpModel test fixture combining MoE routing with NVFP4 quantization; added test_moe_tp_shard_bf16 and test_moe_tp_shard_nvfp4 to validate expert weight sharding, all-reduce presence, and quantized scale handling under TP-sharding.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • PR #9200: Modifies TP sharding pipeline in sharding.py with subgraph-aware parameter handling and load-hook registration patterns
  • PR #9459: Extends MoE/TP sharding infrastructure and modifies sharding logic for nested module structures

Suggested labels

AutoDeploy

Suggested reviewers

  • laikhtewari
  • lucaslie
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description only contains a placeholder issue reference (#12699) and a CodeRabbit summary tag with no actual descriptive content. It fails to explain what the changes accomplish, why they were made, what was tested, or how reviewers should evaluate the changes. Complete the description section with a clear explanation of the changes, their purpose, and testing approach. Remove the placeholder and provide substantive information about piecewise CUDA graph VLM support changes.
Docstring Coverage ⚠️ Warning Docstring coverage is 51.56% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly indicates this PR adds piecewise CUDA graph support for Vision Language Models (VLMs) in AutoDeploy, which aligns with the substantial changes across multiple files including CUDA graph compilation, model architecture modifications, and VLM-specific configurations.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py`:
- Around line 226-237: The capture uses the wrong truncated sizes: after copying
each input from args_batched you compute size_i but then build inputs_truncated
using a fixed bs, which records incorrect shapes for dynamic dims; update the
truncation to narrow each self._input_buffers[i] using the corresponding size
from args_batched (the same size_i computed in the copy loop) so
inputs_truncated is created with per-input extents (refer to args_batched,
dynamic_dims, self._input_buffers, inputs_truncated and bs) — e.g., replace the
list comprehension that uses bs with one that uses the per-index size computed
from args_batched.
- Around line 537-543: In _copy_to_static_buffers the code replaces kwargs[key]
with the full pre-allocated buffer (buf), which changes the logical shape;
instead assign a narrowed view of the buffer that matches the source's runtime
size so address stability is preserved but the graph sees the original shape.
Concretely, after copying (buf.narrow(dyn_dim, 0,
src.shape[dyn_dim]).copy_(src)), set kwargs[key] = buf.narrow(dyn_dim, 0,
src.shape[dyn_dim]) (or assign that view to a variable) instead of kwargs[key] =
buf; operate on symbols _copy_to_static_buffers, _static_input_buffers, buf,
dyn_dim, and kwargs.
- Around line 724-730: The fallback path iterates over result and thus corrupts
plain torch.Tensor outputs (returning a tuple of slices); add an early branch
that detects if result is a torch.Tensor and return _narrow(result) (or the
appropriately truncated tensor) before the existing hasattr(result, "to_tuple")
and isinstance(result, abc.Mapping) checks so that tensor outputs keep their
original type; reference the symbols result, _narrow, hasattr(result,
"to_tuple"), and abc.Mapping when making the change.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py`:
- Line 830: The model double-registers the same nn.Linear under two paths
causing state_dict key mismatches; to fix, remove the duplicate registration by
ensuring only one attribute owns the linear (either keep self.lm_head in the
parent OR let the child call set_lm_head but do NOT assign self.lm_head twice);
specifically update Qwen3_5MoeForCausalLM / Qwen3_5MoeForConditionalGeneration
to stop assigning the same module to both self.lm_head and via set_lm_head(),
and add a proper _tied_weights_keys tuple in Qwen3_5MoeForConditionalGeneration
(and verify the one in Qwen3_5MoeForCausalLM) that maps the single top-level
lm_head key (e.g. ("lm_head.weight","lm_head.bias") or the correct pair used in
your code) so load_state_dict() only expects the HF checkpoint keys.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py`:
- Around line 153-160: The loop collecting compile_targets lets child
GraphModules through when the root mod is a GraphModule because seen contains ""
but the current filter (if p) ignores it; update the logic so that when mod
itself is a GraphModule you mark that fact in seen and then treat an
empty-string entry as matching all children. Concretely: when detecting the root
GraphModule (isinstance(mod, GraphModule)) add "" to seen (or otherwise record
the root), and change the membership check from if any(name.startswith(p + ".")
for p in seen if p): to something that treats p == "" as matching (e.g., if
any(p == "" or name.startswith(p + ".") for p in seen):) so child GraphModules
are skipped when the parent/root GraphModule is already scheduled for
compilation (affecting variables compile_targets, seen, mod.named_modules(),
name, submod).

In `@tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py`:
- Around line 1926-1942: The loop currently TP-shards any scale named
"weight_scale" (from FP8EPShardingInfo.scale_names()) treating it like
NVFP4/CUTLASS; to fix, skip TP-sharding for plain FP8 weight_scale by adding a
guard before calling _tp_shard_moe_scale: if s_name == "weight_scale" and the
sharding/layout indicates plain FP8 (not NVF4/CUTLASS), continue (leave
replicated). Update the condition around _BLOCKED_SCALE_NAMES or add an explicit
check using the sharding/layout flag available in the context (e.g., inspect the
EP sharding info or a format enum), ensuring tp_size, scale_names,
_tp_shard_moe_scale and FP8EPShardingInfo.scale_names() are used to decide
whether to call _tp_shard_moe_scale.

In
`@tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py`:
- Around line 1438-1457: The test's loop over gm_transformed.named_parameters()
is filtering on "experts" but NVFP4MoEOpModel registers parameters as
"expert_{i}_w*", so the shard-shape assertions never run; update the condition
in the loop to detect the actual parameter names (e.g., check for "expert_" or
match r"expert_\d+_w" / name.startswith("expert_") and then the specific
"w1"/"w2"/"w3" suffixes) so the branches for w1, w2, w3 execute and validate
dim0/dim1 against intermediate_size // world_size (and packed k logic) for the
NVFP4 TP-sharded weights in gm_transformed.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e90b7651-1a5f-4af9-bb7d-07e4c2f0186e

📥 Commits

Reviewing files that changed from the base of the PR and between 7ee9e8b and ce45322.

📒 Files selected for processing (8)
  • examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/_graph.py
  • tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

/bot run

@nvchenghaoz
Collaborator Author

@CodeRabbit please review the piecewise related code.

@coderabbitai
Contributor

coderabbitai bot commented Apr 7, 2026

Tip

For best results, initiate chat on the files or code changes.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
cat tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Length of output: 40271


🏁 Script executed:

#!/bin/bash
cat tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py

Length of output: 7180


@nvchenghaoz Here is a detailed review of the piecewise CUDA graph code.


Review: Piecewise CUDA Graph (torch_cudagraph.py + compile_model.py)

🔴 Issues

1. num_tokens not forwarded in PiecewiseCapturedGraph.forward fallback

num_tokens is explicitly extracted from **kwargs in the signature, so it is not present in **kwargs when original_model is called in the fallback branch.

def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any:
    if self.split_gm is not None:
        ...
    return self.original_model(*args, **kwargs)  # num_tokens is silently dropped

If original_model (or any caller down the chain) relies on receiving num_tokens, this will silently break. The fix is straightforward:

return self.original_model(*args, num_tokens=num_tokens, **kwargs)

2. Silent exception swallow in _reconstruct_output

def _reconstruct_output(self, result: Any) -> Any:
    if not isinstance(result, tuple) or self._out_spec is None:
        return result
    try:
        return self._out_spec.unflatten(list(result))
    except Exception:
        return result  # ← exception silently discarded

Any TreeSpec/shape mismatch (e.g. a mismatched out_spec from the monolithic path) will be silently swallowed, returning a raw tuple instead of a structured ModelOutput. This is very hard to debug. At minimum, log a warning:

except Exception as e:
    ad_logger.warning("_reconstruct_output: failed to unflatten output (%s), returning raw tuple", e)
    return result

3. Address-unstable kwargs with no detected dynamic dim are silently skipped

In _allocate_static_input_buffers:

if dyn_dim is not None:
    self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)

If a kwarg tensor has an unstable address across calls but its shape is static (same shape at max_bucket and max_bucket - 1), dyn_dim remains None and no static buffer is allocated. CUDA graph replay will then use the original (unstable) pointer and can crash or silently corrupt state. This should be either:

  • Explicitly handled with torch.empty_like(v1) buffer and a direct full copy (no narrow), or
  • Logged as a warning so it is visible:
else:
    ad_logger.warning(
        "PiecewiseCapturedGraph: kwarg '%s' has unstable address but no dynamic dim found — "
        "not buffered; CUDA graph replay may crash if address changes",
        key,
    )

🟡 Warnings

4. _truncate_output dim-matching heuristic is fragile for square tensors

for d in range(v.ndim):
    if v.shape[d] == bucket:
        return v.narrow(d, 0, num_tokens)

For a tensor with shape [512, 512] and bucket=512, this will truncate on dim=0 (the first match), which may be wrong if the token dimension is dim=1. Consider passing the known token dimension explicitly, or at least adding an assertion that only one dim matches:

matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
if len(matching) > 1:
    ad_logger.warning("_truncate_output: ambiguous token dim for shape %s, bucket=%d", v.shape, bucket)

5. ADPiecewiseRunner.set_current_num_tokens(None) not in try/finally

In DualModeCapturedGraph.forward:

result = self.piecewise(*args, num_tokens=bucket, **kwargs)
ADPiecewiseRunner.set_current_num_tokens(None)   # ← not reached if piecewise raises
if bucket > num_tokens:
    result = self._truncate_output(result, num_tokens, bucket)
return result

If self.piecewise(...) raises, the global num_tokens state is left dirty, which could corrupt subsequent decode calls. Use try/finally:

try:
    result = self.piecewise(*args, num_tokens=bucket, **kwargs)
finally:
    ADPiecewiseRunner.set_current_num_tokens(None)

6. __getattr__ proxy in DualModeCapturedGraph can mask AttributeError on self.monolithic

def __getattr__(self, name: str):
    try:
        return super().__getattr__(name)
    except AttributeError:
        return getattr(self.monolithic.model, name)

If self.monolithic exists but self.monolithic.model raises AttributeError for name, the exception propagates without a clear message. Consider adding a fallback error message:

except AttributeError:
    try:
        return getattr(self.monolithic.model, name)
    except AttributeError:
        raise AttributeError(
            f"'{type(self).__name__}' and its inner model have no attribute '{name}'"
        ) from None

✅ Things that look good

  • compile_model.py GraphModule traversal: The seen set + startswith guard correctly avoids double-compiling nested GMs. Setting full_model=mod for nested targets and None for the root is correct.
  • _capture_inner_kwargs using a pre-hook: The pre-hook fires before the inner module's forward, so tensor references are fresh. The try/finally around handle.remove() correctly avoids hook leaks.
  • CapturedGraph auto-detection of dynamic dims: Using probe_bs = max(1, batch_sizes[0] - 1) sensibly avoids OOB issues. The edge case of batch_sizes[0] == 1 (probe == max) safely defaults to dim=0.
  • out_spec threading: Passing monolithic._out_spec into PiecewiseCapturedGraph after monolithic capture is a clean way to share the tree structure without re-running a warmup.
  • Static buffer detection via data_ptr(): Using data_ptr() comparison across two calls is a robust way to detect address instability, better than static analysis of the graph.
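The data_ptr()-based instability check praised above can be sketched as follows; detect_unstable_kwargs is a hypothetical helper written for illustration, not the code under review:

```python
import torch


def detect_unstable_kwargs(kwargs_a: dict, kwargs_b: dict) -> set:
    """Flag kwarg tensors whose storage address differs between two calls.

    Those tensors need pre-allocated static buffers before CUDA graph
    replay, since replay reads from the addresses recorded at capture.
    """
    unstable = set()
    for key, ta in kwargs_a.items():
        tb = kwargs_b[key]
        if isinstance(ta, torch.Tensor) and ta.data_ptr() != tb.data_ptr():
            unstable.add(key)
    return unstable


stable = torch.zeros(4)  # same tensor on both calls -> same address
fresh_a, fresh_b = torch.zeros(4), torch.zeros(4)  # re-allocated each call
found = detect_unstable_kwargs(
    {"s": stable, "f": fresh_a},
    {"s": stable, "f": fresh_b},
)
print(found)  # {'f'}
```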

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

@CodeRabbit more review please

@coderabbitai
Contributor

coderabbitai bot commented Apr 7, 2026

Tip

For best results, initiate chat on the files or code changes.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
cat tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py
echo "====SEPARATOR===="
cat tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
echo "====SEPARATOR===="
cat tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
echo "====SEPARATOR===="
cat tensorrt_llm/_torch/auto_deploy/utils/_graph.py
echo "====SEPARATOR===="
cat tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Length of output: 295700


@nvchenghaoz Thanks — did a deeper pass across the piecewise path and related integration points (VLM wrapper, compile_model traversal, sharding/quant hooks). Below are concrete, targeted follow-ups with proposed patches.

Deeper review and fixes:

  • PiecewiseCapturedGraph.forward fallback drops num_tokens
    • Risk: original_model won’t receive num_tokens in fallback.
    • Fix:
      """
      return self.original_model(*args, num_tokens=num_tokens, **kwargs)
      """
  • _reconstruct_output swallows errors silently
    • Add a warning to aid debugging:
      """
      except Exception as e:
          ad_logger.warning("_reconstruct_output: failed to unflatten output (%s); returning raw tuple", e)
          return result
      """
  • Address-unstable kwargs with static shape are skipped
    • Today: we only allocate a static buffer when a dynamic dim is found.
    • Safer: also allocate when address is unstable but shape is static; copy full tensor:
      """
      else:
          self._static_input_buffers[key] = (torch.empty_like(v1), None)
      """
    • And in copy path: if dyn_dim is None, do a full copy instead of narrow().
  • Output truncation heuristic in _truncate_output can select the wrong dim for square tensors (e.g., [S, S, H])
    • Prefer an explicit token dimension if known; at minimum, warn when multiple dims equal bucket:
      """
      matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
      if len(matching) != 1:
          ad_logger.warning("_truncate_output: ambiguous dim for shape %s, bucket=%d; dims=%s", tuple(v.shape), bucket, matching)
      d = matching[0] if matching else -1
      return v.narrow(d, 0, num_tokens) if d >= 0 else v
      """
    • Even better: derive token dim once by comparing shapes at capture-time between bucket and probe and store it.
  • DualModeCapturedGraph.forward leaves ADPiecewiseRunner num_tokens set on exception
    • Wrap the piecewise call in try/finally to always clear:
      """
      try:
          result = self.piecewise(*args, num_tokens=bucket, **kwargs)
      finally:
          ADPiecewiseRunner.set_current_num_tokens(None)
      """
  • __getattr__ proxy may mask missing-attr errors with confusing messages
    • Emit a clearer error when neither wrapper nor inner has the attribute.
  • Output dynamic-dim selection for tuple outputs
    • Currently picks the first detected dynamic dim from inputs and applies to outputs. With VLMs, S may be at different axes (e.g., logits [B, S, V]).
    • Suggest: detect output token dim by comparing output shapes across bucket/probe, not by inheriting from inputs; cache it per-output tensor index.
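The suggested capture-time detection of output token dims (comparing shapes across two buckets) could look like this sketch; the helper name is invented:

```python
import torch


def detect_output_dyn_dims(flat_out_big, flat_out_small):
    """Per output tensor, find the dim whose extent changed between two
    capture buckets; None means the output is shape-static."""
    dims = []
    for big, small in zip(flat_out_big, flat_out_small):
        changed = [d for d in range(big.ndim) if big.shape[d] != small.shape[d]]
        # Only commit to a dim when exactly one changed -- anything else
        # is ambiguous and should stay untouched at truncation time.
        dims.append(changed[0] if len(changed) == 1 else None)
    return dims


# Simulate outputs at bucket=8 vs probe=7: logits [B, S, V] vary on dim 1,
# while a pooled tensor [B, H] is static.
big = [torch.zeros(2, 8, 32), torch.zeros(2, 16)]
small = [torch.zeros(2, 7, 32), torch.zeros(2, 16)]
print(detect_output_dyn_dims(big, small))  # [1, None]
```

Caching this result per output index removes the runtime shape-matching heuristic entirely, which is what makes it robust for square tensors.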

Integration touchpoints verified:

  • compile_model.py
    • Collects and compiles nested GraphModules; passes full_model to backend. Good for capturing wrapper→inner kwargs via pre-hook. No change requested.
  • Sharding/Quantization hooks
    • Good prefixes for nested GMs; submodule-level hooks prevent ordering issues with fused-MoE unfusing. Kept as-is.

Proposed diffs:

  1. torch_cudagraph.py — safer fallback, warnings, robust truncation, always-clear num_tokens
# In PiecewiseCapturedGraph.forward(...)
-    return self.original_model(*args, **kwargs)
+    return self.original_model(*args, num_tokens=num_tokens, **kwargs)

# In PiecewiseCapturedGraph._reconstruct_output(...)
     try:
         return self._out_spec.unflatten(list(result))
-    except Exception:
-        return result
+    except Exception as e:
+        ad_logger.warning("_reconstruct_output: failed to unflatten output (%s); returning raw tuple", e)
+        return result

# In PiecewiseCapturedGraph._allocate_static_input_buffers(...)
-    if dyn_dim is not None:
-        self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
+    if dyn_dim is not None:
+        self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
+    else:
+        # Address is unstable but no dynamic dim — still buffer and copy full tensor
+        self._static_input_buffers[key] = (torch.empty_like(v1), None)

# In PiecewiseCapturedGraph.forward(...) where copying kwargs into static buffers
-    buf, dyn_dim = self._static_input_buffers[key]
-    if dyn_dim is None:
-        continue
-    bs = args_batched[0].shape[dyn_dim]
-    buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
-    kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)
+    buf, dyn_dim = self._static_input_buffers[key]
+    if dyn_dim is None:
+        buf.copy_(v_current)
+        kwargs_batched[key] = buf
+    else:
+        bs = args_batched[0].shape[dyn_dim]
+        buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
+        kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)

# In DualModeCapturedGraph.forward(...)
-    result = self.piecewise(*args, num_tokens=bucket, **kwargs)
-    ADPiecewiseRunner.set_current_num_tokens(None)
+    try:
+        result = self.piecewise(*args, num_tokens=bucket, **kwargs)
+    finally:
+        ADPiecewiseRunner.set_current_num_tokens(None)

# In DualModeCapturedGraph.__getattr__(...)
-    try:
-        return super().__getattr__(name)
-    except AttributeError:
-        return getattr(self.monolithic.model, name)
+    try:
+        return super().__getattr__(name)
+    except AttributeError:
+        try:
+            return getattr(self.monolithic.model, name)
+        except AttributeError:
+            raise AttributeError(f"{type(self).__name__} and its inner model have no attribute '{name}'")

# In DualModeCapturedGraph._truncate_output(...)
-    for d in range(v.ndim):
-        if v.shape[d] == bucket:
-            return v.narrow(d, 0, num_tokens)
-    return v
+    matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
+    if len(matching) != 1:
+        ad_logger.warning("_truncate_output: ambiguous dim for shape %s, bucket=%d; dims=%s",
+                          tuple(v.shape), bucket, matching)
+    d = matching[0] if matching else -1
+    return v.narrow(d, 0, num_tokens) if d >= 0 else v
  2. Optional: make output-dim detection explicit at capture-time
  • If feasible in your flow, during capture_graph run with two buckets record, per-output tensor index, the dim that changed; store in self._output_dynamic_dims and use that in _truncate_output. I can sketch this if you want to adopt it now.
  3. utils/_graph.py — unwrap both torch_dist_all_gather and trtllm_dist_all_gather for lm_head node detection
  • Current change mentions unwrapping trtllm all_gather. If we ever fall back to torch-dist backend, also handle torch_dist_all_gather:
# In get_lm_head_node(...)
- if node.target is torch.ops.auto_deploy.trtllm_dist_all_gather:
+ if node.target in {
+     torch.ops.auto_deploy.trtllm_dist_all_gather.default,
+     torch.ops.auto_deploy.torch_dist_all_gather.default,
+ }:
      node = list(node.all_input_nodes)[0]

Why these now:

  • They are low-risk, localized changes that improve correctness and debuggability in edge cases we’ll hit with VLM piecewise replay (static-shaped kwarg buffers; ambiguous S-dim; exceptions).
  • They don’t alter capture semantics or compiled graphs — they only affect wrapper behavior and logging.

If you want, I can also add a tiny unit test to assert:

  • num_tokens is forwarded in fallback,
  • _truncate_output picks the intended dim for logits [B, S, V],
  • address-unstable static kwargs get buffered and don’t cause replay errors.

Happy to prepare those next.
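For the _truncate_output case, such a test could exercise a sketch of the proposed guard; truncate_output here is a hypothetical standalone helper mirroring the suggested fix, not the actual method:

```python
import torch


def truncate_output(v: torch.Tensor, num_tokens: int, bucket: int) -> torch.Tensor:
    """Sketch of the proposed fix: truncate only when the token dim is
    unambiguous; leave the tensor untouched otherwise."""
    matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
    if len(matching) != 1:
        return v  # ambiguous or absent -- do not guess
    return v.narrow(matching[0], 0, num_tokens)


logits = torch.zeros(2, 8, 32)  # [B, S, V]; only S matches the bucket
square = torch.zeros(8, 8)      # ambiguous: both dims match the bucket
print(truncate_output(logits, 5, 8).shape)  # torch.Size([2, 5, 32])
print(truncate_output(square, 5, 8).shape)  # torch.Size([8, 8])
```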


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo ">>> grep torch_cudagraph core spots"
rg -n --no-heading -C3 $'class PiecewiseCapturedGraph|def _reconstruct_output\\(|def _truncate_output\\(|class DualModeCapturedGraph|def __getattr__\\(|class CapturedGraph|dynamic_dims|set_current_num_tokens' tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py || true

echo ">>> show get_lm_head_node updates"
rg -n --no-heading -C4 'def get_lm_head_node|trtllm_dist_all_gather|torch_dist_all_gather' tensorrt_llm/_torch/auto_deploy/utils/_graph.py || true

Length of output: 7041


@nvchenghaoz Thanks — I dug deeper into the piecewise path and related integration points. Findings and concrete patches below.

What looks good

  • DualModeCapturedGraph now always clears ADPiecewiseRunner num_tokens via try/finally (Lines 792-794).
  • CapturedGraph auto-detects per-input dynamic dims and uses narrow() at replay; this is solid.
  • The compile traversal that compiles nested GraphModules and passes full_model for inner-kwargs capture is sound.

Actionable fixes and improvements

  1. Bug: num_tokens is dropped on eager fallback in PiecewiseCapturedGraph.forward
  • Evidence: Lines 650-651 return original_model without forwarding num_tokens.
  • Fix: pass it through.

Patch:

# tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

class PiecewiseCapturedGraph(nn.Module):
    def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any:
        if self.split_gm is not None:
            self._copy_to_static_buffers(kwargs)
            ADPiecewiseRunner.set_current_num_tokens(num_tokens)
            try:
                result = self.split_gm(*args, **kwargs)
            finally:
                ADPiecewiseRunner.set_current_num_tokens(None)
            return self._reconstruct_output(result)
-       return self.original_model(*args, **kwargs)
+       return self.original_model(*args, num_tokens=num_tokens, **kwargs)
  2. Robustness: handle static-shaped but address-unstable kwargs in static buffer path
  • Today static buffers are allocated only when a dynamic dim is detected. If a kwarg tensor’s address changes across calls but its shape is static, replay can still be unsafe without buffering.
  • Change: allocate a buffer even when dyn_dim is None and do a full copy; use narrow only when a dyn_dim exists.

Patch (conceptual; adjust to your _copy_to_static_buffers/_allocate_static_input_buffers structure):

# When detecting instability:
if dyn_dim is not None:
    self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
else:
    # address unstable but static shape — still buffer it
    self._static_input_buffers[key] = (torch.empty_like(v1), None)

# When copying into buffers before replay:
buf, dyn_dim = self._static_input_buffers[key]
if dyn_dim is None:
    buf.copy_(v_current)
    kwargs_batched[key] = buf
else:
    bs = args_batched[0].shape[dyn_dim]
    buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
    kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)
  3. Safer truncation when bucket > real num_tokens
  • The current _truncate_output picks the first dim whose size equals bucket (see the docstring around Line 740). This is ambiguous for square-ish shapes (e.g., [B, S, S] or [S, S, H]).
  • Minimal improvement: warn when multiple dims match and fall back gracefully.

Patch:

def _truncate_output(self, result: Any, num_tokens: int, bucket: int) -> Any:
    def cut(v: torch.Tensor) -> torch.Tensor:
        matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
        if len(matching) != 1:
            ad_logger.warning(
                "_truncate_output: ambiguous dim for shape %s, bucket=%d; dims=%s",
                tuple(v.shape), bucket, matching
            )
        d = matching[0] if matching else -1
        return v.narrow(d, 0, num_tokens) if d >= 0 else v
    if isinstance(result, tuple):
        return tuple(cut(v) if isinstance(v, torch.Tensor) else v for v in result)
    return cut(result) if isinstance(result, torch.Tensor) else result
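
To see why the warning matters, here is how the dim-matching heuristic (mirroring the `matching` list above, with plain tuples standing in for tensor shapes) behaves on concrete shapes:

```python
def matching_dims(shape, bucket):
    # all dims whose size equals the padded bucket size
    return [d for d in range(len(shape)) if shape[d] == bucket]

# unambiguous: logits [B, S, V] with S == bucket -> exactly one candidate
assert matching_dims((2, 256, 151936), 256) == [1]
# ambiguous: square map [B, S, S] -> two candidates, the warning fires
assert matching_dims((2, 256, 256), 256) == [1, 2]
```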

Optional but recommended: during capture, compare each output tensor’s shape at bucket vs probe to record its own token-dim index (e.g., self._output_dynamic_dims[i]), then use that instead of the heuristic.
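
A minimal sketch of that capture-time detection, operating on plain shape tuples recorded at two capture buckets (the function name and the two-bucket probing are assumptions about how the capture flow would expose shapes):

```python
def detect_output_dynamic_dims(shapes_at_bucket, shapes_at_probe):
    """Return, per output tensor, the single dim that changed between the
    two capture buckets, or None when zero or multiple dims changed."""
    dynamic_dims = []
    for s_a, s_b in zip(shapes_at_bucket, shapes_at_probe):
        changed = [d for d in range(len(s_a)) if s_a[d] != s_b[d]]
        dynamic_dims.append(changed[0] if len(changed) == 1 else None)
    return dynamic_dims


# logits [B, S, V] captured at S=64 and S=128 -> the token dim is 1,
# even though V could accidentally equal a bucket size
dims = detect_output_dynamic_dims([(2, 64, 151936)], [(2, 128, 151936)])
assert dims == [1]
```

At replay, _truncate_output would then narrow along the recorded dim instead of guessing from the bucket size.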

  4. Observability: log on output unflatten failure
  • If _reconstruct_output uses a try/except around unflatten, add a warning to surface mismatches. If it doesn’t currently catch, consider adding this guard.

Patch:

def _reconstruct_output(self, result: Any) -> Any:
    if not isinstance(result, tuple) or self._out_spec is None:
        return result
    try:
        return self._out_spec.unflatten(list(result))
    except Exception as e:
        ad_logger.warning("_reconstruct_output: failed to unflatten output (%s); returning raw tuple", e)
        return result
  5. Interop: unwrap both TRT-LLM and Torch all_gather around lm_head
  • Evidence: get_lm_head_node currently unwraps trtllm_dist_all_gather only (Lines 581-583). For the torch backend we should also handle torch_dist_all_gather.

Patch:

# tensorrt_llm/_torch/auto_deploy/utils/_graph.py

from torch.fx import GraphModule, Node
import torch

def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Node:
    ...
-   if is_op(lm_head_node, torch.ops.auto_deploy.trtllm_dist_all_gather):
-       lm_head_node = lm_head_node.all_input_nodes[0]
+   if is_op(lm_head_node, {
+       torch.ops.auto_deploy.trtllm_dist_all_gather,
+       torch.ops.auto_deploy.torch_dist_all_gather,
+   }):
+       lm_head_node = lm_head_node.all_input_nodes[0]
    return lm_head_node

Nice-to-have

  • DualModeCapturedGraph.__getattr__: if both the wrapper and the inner model lack the attribute, raise a clear message instead of a confusingly chained AttributeError. If you want, I can supply a tiny tweak for this.
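
For reference, the tweak could look like this (a sketch only; the real class wraps an nn.Module, here a plain object stands in, and the class name is hypothetical):

```python
class DualModeWrapperSketch:
    def __init__(self, inner):
        object.__setattr__(self, "_inner", inner)

    def __getattr__(self, name):
        # only reached when normal lookup on the wrapper itself fails
        try:
            return getattr(object.__getattribute__(self, "_inner"), name)
        except AttributeError:
            # `from None` drops the inner AttributeError chain,
            # so the user sees one clear message instead of two stacked ones
            raise AttributeError(
                f"neither {type(self).__name__!r} nor the wrapped model "
                f"has attribute {name!r}"
            ) from None


class Inner:
    hidden_size = 4096


w = DualModeWrapperSketch(Inner())
assert w.hidden_size == 4096  # delegated to the inner model
```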

Would you like me to stage these as a small follow-up patch and add a couple of focused unit tests (num_tokens passthrough and truncation-dim disambiguation)? If yes, I can prepare a plan next.

@nvchenghaoz (Collaborator, Author)

Closing this one since it contains other work. This branch was just to check whether the change fixes the Qwen VLMs. Will create a new PR with the Piecewise CG changes only.
